import unittest

import torch
import torch.nn.functional as F

SKIP_TEST = None
try:
    from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func
except ImportError as e:
    SKIP_TEST = e


@unittest.skipIf(SKIP_TEST, f"Skip testing due to: {SKIP_TEST}")
class FusedSoftmaxTest(unittest.TestCase):
    def setUp(self, seed=1234):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.seq_length = 80
        self.sequences = 10
        self.hidden_dim = 1024  # unused by this test; kept to document the model size
        self.heads = 16
        self.dropout_prob = 0.0

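        # Random boolean padding mask over key positions, converted to an
        # additive fp16 mask: masked positions get a -10000 bias so softmax
        # drives their weights to ~0.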
        self.mask = (torch.randn(self.sequences, self.seq_length) > 0).cuda()
        self.mask = self.mask.half() * -10000
        self.ref_inputs = torch.randn(
            self.heads * self.sequences,
            self.seq_length,
            self.seq_length,
            dtype=torch.float16,
            device=torch.device("cuda"),
        ).requires_grad_(True)

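        # Independent leaf tensor with identical values, so the reference and
        # fused paths start from the same inputs but accumulate gradients
        # separately.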
        self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)

    def test_fused_softmax(self):
        grads = torch.randn_like(self.tst_inputs)
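        # Reference path: reshape to (sequences, heads, q_len, k_len) so the
        # additive mask broadcasts over heads and query positions, then softmax
        # over the key dimension.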
        y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length)
        y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2)
        y_ref = y_ref.view(self.sequences * self.heads, self.seq_length, self.seq_length)
        y_ref = F.softmax(y_ref, dim=-1)
        # torch._fused_dropout is a private op whose second argument is a *keep*
        # probability; with dropout_prob == 0.0 the public F.dropout is an
        # equivalent no-op and returns a plain tensor instead of a tuple.
        y_ref = F.dropout(y_ref, p=self.dropout_prob, training=True)

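        # Fused path: mask + softmax + dropout in a single apex kernel. The
        # positional args (per the apex source) are: is_training, heads, inputs,
        # pad_mask, mask_additive, dropout_prob.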
        y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0)
        y_ref.backward(grads)
        y_tst.backward(grads)

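        # The fused op must leave its inputs untouched, and both the outputs and
        # the input gradients must match the eager reference within fp16
        # tolerances.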
        torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(y_ref, y_tst, atol=1e-3, rtol=1e-3)
        torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
    unittest.main()
